Skip to content

[Quantization] Support Quark W8A8 INT8 MoE inference#36320

Merged
tjtanaa merged 5 commits intovllm-project:mainfrom
JoursBleu:feat/quark-w8a8-int8-moe
Apr 9, 2026
Merged

[Quantization] Support Quark W8A8 INT8 MoE inference#36320
tjtanaa merged 5 commits intovllm-project:mainfrom
JoursBleu:feat/quark-w8a8-int8-moe

Conversation

@JoursBleu
Copy link
Copy Markdown
Contributor

@JoursBleu JoursBleu commented Mar 7, 2026

Purpose

MoE models quantized by AMD Quark with W8A8 INT8 (per-channel weight + per-token dynamic activation) cannot be loaded in vLLM. For example, quantizing MiniMax-M2.1 (456B MoE) with Quark's ptpc_int8 scheme produces a model that fails at startup with:

  1. quark.py: _get_scheme_from_config() only recognizes static per-tensor W8A8 INT8 via _is_static_tensor_w8a8, missing the dynamic per-token + per-channel weight config → RuntimeError("Unsupported quantization scheme")
  2. quark_moe.py: No INT8 MoE method exists (only Fp8 and OCP_MX) → RuntimeError("Unsupported FusedMoe scheme")
  3. fused_moe/utils.py: _int8_quantize() hard-asserts per_act_token when block_shape is None, blocking per-tensor static/dynamic INT8 paths

This PR adds:

  • _is_dynamic_per_token_w8a8() detection and routing to QuarkW8A8Int8(is_static_input_scheme=False) in quark.py
  • QuarkW8A8Int8MoEMethod in quark_moe.py supporting both per-tensor and per-channel weight scales
  • Defensive branching in _int8_quantize() for per-token / static per-tensor / dynamic per-tensor paths
  • Fix trust_remote_code=FalseTrue in get_config() call (required for custom models like MiniMax-M2.1)

Test Plan

# Start vLLM server with Quark W8A8 INT8 MoE model
VLLM_WORKER_MULTIPROC_METHOD=spawn vllm serve /path/to/MiniMax-M2.1-quark-ptpc-int8 \
    --tensor-parallel-size 8 --trust-remote-code --max-model-len 8192

# MMLU via lm-evaluation-harness
lm_eval --model local-completions \
  --model_args "model=MiniMax-M2.1-quark-ptpc-int8,base_url=http://localhost:8000/v1/completions,num_concurrent=50,tokenized_requests=False" \
  --tasks mmlu --num_fewshot 5

# GSM8K via vLLM's built-in eval script
python tests/evals/gsm8k/gsm8k_eval.py --port 8000 --num-shots 8

Test Result

Tested on MiniMax-M2.1 (456B MoE) quantized with Quark ptpc_int8, served with vLLM on 8x GPU:

Benchmark BF16 (official) Quark W8A8 INT8 Recovery
MMLU (5-shot) 86.2% (5-shot) 85.48% 99.2%
GSM8K (8-shot) 92.0% (8-shot) 93.7% 101.8%

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions
Copy link
Copy Markdown

github-actions Bot commented Mar 7, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors.

You ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for Quark W8A8 INT8 MoE inference, which is a valuable addition. The changes are well-structured and address the missing functionality for this quantization scheme. My review focuses on improving code clarity and maintainability. I've pointed out a few instances of misleading documentation and variable names, as well as one case of unreachable code. Addressing these points will enhance the long-term quality of the codebase.

Note: Security Review did not run due to the size of the PR.

Comment thread vllm/model_executor/layers/fused_moe/utils.py Outdated
Comment thread vllm/model_executor/layers/quantization/quark/quark.py Outdated
Comment thread vllm/model_executor/layers/quantization/quark/quark_moe.py Outdated
@JoursBleu JoursBleu force-pushed the feat/quark-w8a8-int8-moe branch from 63483a5 to f53e7e3 Compare March 8, 2026 05:24
@JoursBleu JoursBleu marked this pull request as ready for review March 8, 2026 05:32
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Mar 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @JoursBleu.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify Bot added the needs-rebase label Mar 13, 2026
@JoursBleu JoursBleu force-pushed the feat/quark-w8a8-int8-moe branch from f53e7e3 to 1125a4e Compare March 27, 2026 07:33
@mergify mergify Bot removed the needs-rebase label Mar 27, 2026
@JoursBleu JoursBleu force-pushed the feat/quark-w8a8-int8-moe branch from 1125a4e to 526ba9c Compare March 27, 2026 07:45
Copy link
Copy Markdown
Contributor

@BowenBao BowenBao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In addition to inline comments, could you add a small model for integration test? otherwise looks good.

Comment thread vllm/model_executor/layers/quantization/quark/quark.py Outdated
Comment thread vllm/model_executor/layers/quantization/quark/quark_moe.py Outdated
Comment thread vllm/model_executor/layers/fused_moe/utils.py Outdated
@JoursBleu JoursBleu force-pushed the feat/quark-w8a8-int8-moe branch from 526ba9c to c619159 Compare March 29, 2026 10:23
@JoursBleu
Copy link
Copy Markdown
Contributor Author

All review comments addressed, rebased onto latest main, and added integration test with a tiny MoE model. GSM8K 8-shot accuracy re-verified. @BowenBao

Signed-off-by: kangletian <Letian.Kang@amd.com>
@JoursBleu JoursBleu force-pushed the feat/quark-w8a8-int8-moe branch from c619159 to 2e825b2 Compare March 30, 2026 03:27
Copy link
Copy Markdown
Contributor

@BowenBao BowenBao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, please fix pre-commit issues.

@tjtanaa tjtanaa added rocm Related to AMD ROCm ready ONLY add when PR is ready to merge/full CI is needed labels Mar 30, 2026
@github-project-automation github-project-automation Bot moved this to Todo in AMD Mar 30, 2026
@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Mar 30, 2026

Hi @JoursBleu, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Copy link
Copy Markdown
Collaborator

@tjtanaa tjtanaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but pre-commit has to be fixed before merging.

@mergify
Copy link
Copy Markdown
Contributor

mergify Bot commented Mar 31, 2026

Hi @JoursBleu, the pre-commit checks have failed. Please run:

uv pip install pre-commit>=4.5.1
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

JoursBleu and others added 2 commits March 31, 2026 09:23
@JoursBleu
Copy link
Copy Markdown
Contributor Author

This PR should be ready to merge. @tjtanaa @BowenBao

@tjtanaa tjtanaa enabled auto-merge (squash) April 9, 2026 15:58
@tjtanaa tjtanaa merged commit 827268e into vllm-project:main Apr 9, 2026
70 of 71 checks passed
@github-project-automation github-project-automation Bot moved this from Todo to Done in AMD Apr 9, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
)

Signed-off-by: kangletian <Letian.Kang@amd.com>
whk-lab pushed a commit to whk-lab/vllm that referenced this pull request Apr 23, 2026
)

Signed-off-by: kangletian <Letian.Kang@amd.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

3 participants